# This code is part of cqlib.
#
# Copyright (C) 2024 China Telecom Quantum Group, QuantumCTek Co., Ltd.,
# Center for Excellence in Quantum Information and Quantum Physics.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.


"""
This module provides functionality to convert a cqlib Circuit object into
an OpenQASM 2.0 formatted string or file. It includes handling of various
quantum gates, measurements, and optional annotations.

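Functions:
- dumps(circuit, header, precision, annotate): Convert a Circuit to an OpenQASM 2.0 string.
- dump(circuit, filename, precision, annotate): Write the OpenQASM 2.0 string of a Circuit to a file.
- translate(instruction, qubits, params): Translate a single instruction into OpenQASM 2.0 statements.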
"""

import math

from cqlib import _version
from cqlib.circuits.circuit import Circuit
from cqlib.circuits.instruction import Instruction
from cqlib.circuits.parameter import Parameter
from cqlib.exceptions import CqlibError

# Duration attributed to a single ``id`` gate. ``translate`` combines this value
# with the duration parameter of an ``I`` instruction to decide how many ``id``
# gates to emit.
# NOTE(review): units presumably match the ``I`` gate's duration parameter
# (e.g. ns) — confirm against the gate definition.
I_DURATION = 60


def dumps(
        circuit: Circuit,
        header: str | None = None,
        precision: int = 10,
        annotate: bool = False,
) -> str:
    """
    Convert a cqlib Circuit object to an OpenQASM 2.0 formatted string.

    All circuit qubits are mapped onto a single quantum register ``q`` (in the
    order of ``circuit.qubits``). Measurement results go to a single classical
    register ``c``. Its bits are allocated in the order the measured qubits
    appear in the circuit.

    Args:
        circuit (Circuit): The quantum circuit to convert.
        header (str, optional): Additional header information to include as
            comments. Every line of a multi-line header is commented out.
        precision (int, optional): The number of decimal places to round parameters to.
        annotate (bool, optional): If True, include original instructions as comments.

    Returns:
        str: The OpenQASM 2.0 representation of the circuit.

    Raises:
        CqlibError: If an instruction carries a symbolic ``Parameter``, which
            OpenQASM 2.0 cannot express.
    """
    qubits = circuit.qubits

    # Comment out every line of the user header; otherwise a multi-line header
    # would inject non-comment text and make the program invalid.
    header_comment = '\n'.join(f'//{line}' for line in header.splitlines()) if header else ''
    preamble = f"// Generated by Cqlib v{_version.__version__}\n" \
               f"{header_comment}\n" \
               "OPENQASM 2.0;\n" \
               'include "qelib1.inc";\n'

    qreg_def = f'qreg q[{len(qubits)}];\n'
    if annotate:
        qreg_def = f'// Qubits: {",".join(map(str, qubits))}\n' + qreg_def

    measure_qubits = []
    operations = []

    for ins in circuit.instruction_sequence:
        if annotate:
            operations.append(f'// {str(ins)}')

        qs = [qubits.index(q) for q in ins.qubits]
        ps = []
        for p in ins.instruction.params:
            if isinstance(p, Parameter):
                raise CqlibError('OpenQASM2.0 do not support parameter')
            ps.append(round(p, precision))

        if ins.instruction.name == 'M':
            # One classical bit per measured qubit; emit a measure for each
            # qubit so multi-qubit measurements are not truncated.
            for qubit, q_index in zip(ins.qubits, qs):
                measure_qubits.append(qubit)
                operations.append(f'measure q[{q_index}] '
                                  f'-> c[{len(measure_qubits) - 1}];')
            continue
        if item := translate(ins.instruction, qs, ps):
            operations.append(item)

    creg_def = f'creg c[{len(measure_qubits)}];\n'
    if annotate:
        # Prepend the annotation only; the declaration itself must appear once.
        creg_def = '// Classical register, obtained based on the measured qubits.\n' \
                   f'// {",".join(map(str, measure_qubits))}\n' + creg_def

    return '\n'.join([
        preamble,
        qreg_def,
        creg_def,
        *operations,
    ])


def dump(
        circuit: Circuit,
        filename: str,
        precision: int = 10,
        annotate: bool = False,
        header: str | None = None,
) -> None:
    """
    Write the OpenQASM 2.0 representation of a cqlib Circuit to a file.

    The file is opened in text mode ``'w'`` (created or truncated) and written
    as UTF-8.

    Args:
        circuit (Circuit): The quantum circuit to convert and write.
        filename (str): The path to the file where the OpenQASM code will be written.
        precision (int, optional): The number of decimal places to round parameters to.
        annotate (bool, optional): If True, include original instructions as comments.
        header (str, optional): Additional header information to include as
            comments, forwarded unchanged to ``dumps``.
    """
    qasm = dumps(circuit, header=header, precision=precision, annotate=annotate)
    with open(filename, 'w', encoding='utf-8') as fp:
        fp.write(qasm)


def translate(
        instruction: Instruction,
        qubits: list[int],
        params: list[int | float]
) -> str:
    """
    Translate a quantum instruction into its corresponding OpenQASM 2.0 representation.

     Args:
        instruction (Instruction): The quantum gate instruction to be translated. It must
            have the following attributes:
        qubits (list[int]): A list of qubit indices that the instruction is applied to.
            The order of qubits should match the gate's requirements (e.g., control and
            target qubits for two-qubit gates).
        params (list[int | float]): A list of numerical parameters associated with the
            instruction. These parameters are used for parameterized gates such as rotations
            (e.g., RX, RY, RZ) and must be provided in the order expected by the gate.
    """
    gate = instruction.name.lower()
    item = ''
    match instruction.name:
        case 'TD':
            item = f'tdg q[{qubits[0]}];'
        case 'SD':
            item = f'sdg q[{qubits[0]}];'
        case 'H' | 'X' | 'Y' | 'Z' | 'S' | 'T':
            item = f'{gate} q[{qubits[0]}];'
        case 'RX' | 'RY' | 'RZ':
            item = f'{gate}({params[0]}) q[{qubits[0]}];'
        case 'CX' | 'CY' | 'CZ':
            item = f'{gate} q[{qubits[0]}], q[{qubits[1]}];'
        case 'CCX':
            item = f'{gate} q[{qubits[0]}], q[{qubits[1]}], q[{qubits[2]}];'
        case 'U':
            item = f'u3({",".join(map(str, params))}) q[{qubits[0]}];'
        case 'X2P':
            item = f'sx q[{qubits[0]}];'
        case 'X2M':
            item = f'sxdg q[{qubits[0]}];'
        case 'Y2P':
            item = f'ry(pi/2) q[{qubits[0]}];'
        case 'Y2M':
            item = f'ry(-pi/2) q[{qubits[0]}];'
        case 'XY':
            item = f'rz(pi/2 - {params[0]}) q[{qubits[0]}];\n' \
                   f'y q[{qubits[0]}];\n' \
                   f'rz({params[0]} - pi/2) q[{qubits[0]}];'
        case 'XY2P':
            item = f'rz(pi/2 - {params[0]}) q[{qubits[0]}];\n' \
                   f'ry(pi/2) q[{qubits[0]}];\n' \
                   f'rz({params[0]} - pi/2) q[{qubits[0]}];'
        case 'XY2M':
            item = f'rz(-pi/2 - {params[0]}) q[{qubits[0]}];\n' \
                   f'ry(pi/2) q[{qubits[0]}];\n' \
                   f'rz({params[0]} + pi/2) q[{qubits[0]}];'
        case 'RXY':
            item = f"u3({params[1]},{params[0]} - pi/2,pi/2 - {params[0]}) q[{qubits[0]}];"
        case 'CRX':
            item = f's q[{qubits[1]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];\n' \
                   f'ry(-{params[0]} / 2) q[{qubits[1]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];\n' \
                   f'ry({params[0]} / 2) q[{qubits[1]}];\n' \
                   f'sdg q[{qubits[1]}];'
        case 'CRY':
            item = f'ry({params[0]} / 2) q[{qubits[1]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];\n' \
                   f'ry(-{params[0]} / 2) q[{qubits[1]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];'
        case 'CRZ':
            item = f'rz({params[0]} / 2) q[{qubits[1]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];\n' \
                   f'rz(-{params[0]} / 2) q[{qubits[1]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];'
        case 'SWAP':
            item = f'cx q[{qubits[0]}], q[{qubits[1]}];\n' \
                   f'cx q[{qubits[1]}], q[{qubits[0]}];\n' \
                   f'cx q[{qubits[0]}], q[{qubits[1]}];'
        case 'B':
            item = f'barrier {",".join([f"q[{q}]" for q in qubits])};'
        case 'I':
            t = instruction.params[0]
            for _ in range(math.ceil(t // I_DURATION)):
                item = f'id q[{qubits[0]}];'
    return item
